import torch
import torch.nn as nn
from tqdm import tqdm
from datasets import load_dataset

def evaluate_perplexity(model, tokenizer):
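    """
    Compute perplexity on the wikitext-2 test set by scoring fixed-length,
    non-overlapping token windows and averaging the negative log-likelihood
    per token.
    """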
    def _perplexity(nlls, n_samples, seqlen):
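        # perplexity = exp(total NLL / total number of tokens scored)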
        return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen))
    
    # Load wikitext-2's raw test split and tokenize it as one contiguous stream,
    # joining articles with blank lines
    data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    data = tokenizer("\n\n".join(data['text']), return_tensors='pt')
    data = data.input_ids.to(model.device)

    seqlen = 2048  # non-overlapping evaluation window; the conventional setting for wikitext-2 comparisons
    model = model.eval()
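    # Number of full windows; any trailing remainder of the stream is dropped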
    n_samples = data.numel() // seqlen
    
    loss_fct = nn.CrossEntropyLoss()  # stateless, so hoisted out of the loop
    nlls = []

    with tqdm(range(n_samples), desc="Perplexity -") as progress_bar:
        for i in progress_bar:
            start_index = (i * seqlen)
            end_index = ((i + 1) * seqlen)
            batch = data[:, start_index:end_index]  # already on model.device
            with torch.no_grad():
                logits = model(batch).logits
            # Standard next-token shift: the logit at position t is scored
            # against the token at position t + 1
            shift_logits = logits[:, :-1, :].contiguous().float()
            shift_labels = batch[:, 1:]
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            # Scale the mean loss back to a per-window sum. The window actually
            # contributes seqlen - 1 scored tokens; multiplying by seqlen keeps
            # this consistent with the (n_samples * seqlen) divisor in
            # _perplexity and with common quantization eval scripts
            neg_log_likelihood = loss.float() * seqlen
            nlls.append(neg_log_likelihood)

            # Show the running perplexity over the windows processed so far
            curr_ppl = _perplexity(nlls, i + 1, seqlen)
            progress_bar.set_description(f"Perplexity {curr_ppl:.3f}")

    ppl = _perplexity(nlls, n_samples, seqlen)
    
    return ppl.item()

if __name__ == '__main__':
    from transformers import AutoModelForCausalLM, AutoTokenizer
    
    model_path = 'mistralai/Mistral-7B-Instruct-v0.1'
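    # device_map="auto" shards the model across available devices via
    # accelerate. from_pretrained loads weights in float32 by default; if VRAM
    # is tight, also passing torch_dtype=torch.float16 (an assumption about
    # your hardware, not part of the original script) halves the footprint.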
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    ppl = evaluate_perplexity(model, tokenizer)
    print(f"wikitext-2 perplexity: {ppl:.3f}")
